{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Novel Pose and Expression Synthesis for a Pretrained Avatar"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "os.chdir(\"..\")\n",
    "from nha.util.render import create_intrinsics_matrix\n",
    "import torch\n",
    "from nha.models.nha_optimizer import NHAOptimizer\n",
    "from nha.util.general import dict_2_device\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ckpt = \"path-to-ckpt-file\"\n",
    "tracking_results=\"path-to-tracking-results-file\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Restore the avatar from the checkpoint, then switch it to eval mode and move it to the GPU.\n",
    "avatar = NHAOptimizer.load_from_checkpoint(ckpt)\n",
    "avatar = avatar.eval().cuda()\n",
    "\n",
    "# Tracking results; the synthesis function reads the entries\n",
    "# \"image_size\", \"K\", \"shape\", \"translation\" and \"RT\" from it.\n",
    "tr = np.load(tracking_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def synthesize_novel_poses_and_expressions(\n",
    "    expr=torch.zeros(100, dtype=torch.float),\n",
    "    pose=torch.zeros(15, dtype=torch.float),\n",
    "    image_size=(512, 512),\n",
    "):\n",
    "    \"\"\"\n",
    "    Render the loaded avatar for the given expression and pose parameters.\n",
    "\n",
    "    Relies on the notebook globals `avatar` (the model) and `tr` (the tracking results).\n",
    "\n",
    "    :param expr: expression parameters, passed on as `flame_expr`; defaults to 100 zeros\n",
    "    :param pose: pose offsets that are added to the neutral joint rotations and passed on\n",
    "                 as `flame_pose`; defaults to 15 zeros\n",
    "    :param image_size: (height, width) of the rendered output\n",
    "    :return: tuple (rgba, shaded_mesh) holding the results of `avatar.forward` and\n",
    "             `avatar.predict_shaded_mesh` for the assembled batch\n",
    "    \"\"\"\n",
    "    out_h, out_w = image_size\n",
    "    src_h, src_w = tr['image_size']\n",
    "    longest_side = max(src_h, src_w)\n",
    "\n",
    "    # factors that rescale the stored camera parameters to the requested output resolution\n",
    "    fx_scale = longest_side * out_w / src_w\n",
    "    fy_scale = longest_side * out_h / src_h\n",
    "    cx_scale = out_w\n",
    "    cy_scale = out_h\n",
    "\n",
    "    # Entry 0 of tr[\"K\"] feeds both fx and fy, entries 1 and 2 feed px and py.\n",
    "    # NOTE(review): K appears to hold one shared focal length followed by the principal point - confirm.\n",
    "    K = tr[\"K\"]\n",
    "    cam_intrinsics = create_intrinsics_matrix(\n",
    "        fx=K[0] * fx_scale,\n",
    "        fy=K[0] * fy_scale,\n",
    "        px=K[1] * cx_scale,\n",
    "        py=K[2] * cy_scale,\n",
    "    )\n",
    "\n",
    "    # neutral pose: global, neck and jaw rotation followed by the eye rotation used twice\n",
    "    neutral_rots = avatar._flame.get_neutral_joint_rotations()\n",
    "    joint_order = (\"global\", \"neck\", \"jaw\", \"eyes\", \"eyes\")\n",
    "    default_pose = torch.cat([neutral_rots[joint] for joint in joint_order], dim=0).cpu()\n",
    "\n",
    "    # single-sample batch with the inputs for the avatar; \"rgb\" is an all-zero placeholder image\n",
    "    batch = {\n",
    "        \"flame_shape\": torch.from_numpy(tr[\"shape\"][None]).float(),\n",
    "        \"flame_expr\": expr[None],\n",
    "        \"flame_pose\": (pose + default_pose)[None],\n",
    "        \"flame_trans\": torch.from_numpy(tr[\"translation\"][[0]]).float(),\n",
    "        \"cam_intrinsic\": cam_intrinsics[None],\n",
    "        \"cam_extrinsic\": torch.from_numpy(tr[\"RT\"]).float()[None],\n",
    "        \"rgb\": torch.zeros(1, 3, out_h, out_w),\n",
    "    }\n",
    "    batch = dict_2_device(batch, avatar.device)\n",
    "\n",
    "    # make prediction\n",
    "    rgba = avatar.forward(batch, symmetric_rgb_range=False)\n",
    "    shaded_mesh = avatar.predict_shaded_mesh(batch)\n",
    "    return rgba, shaded_mesh"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# Manipulate the following parameters to control the avatar.\n",
    "\n",
    "# expression parameters (entries 0-4 of the expression vector)\n",
    "e0 = 0\n",
    "e1 = 0\n",
    "e2 = 0\n",
    "e3 = 0\n",
    "e4 = 0\n",
    "\n",
    "# global rotation (pose entries 0-2)\n",
    "rot0 = 0\n",
    "rot1 = np.pi\n",
    "rot2 = 0\n",
    "\n",
    "# neck rotation (pose entries 3-5)\n",
    "neck0 = 0\n",
    "neck1 = 0\n",
    "neck2 = 0\n",
    "\n",
    "# jaw movement (pose entry 6)\n",
    "jaw = 0\n",
    "\n",
    "# Pack the controls into the parameter vectors; every entry not listed above stays zero.\n",
    "expr = torch.zeros(100, dtype=torch.float)\n",
    "pose = torch.zeros(15, dtype=torch.float)\n",
    "expr[:5] = torch.tensor([e0, e1, e2, e3, e4], dtype=torch.float)\n",
    "pose[:7] = torch.tensor([rot0, rot1, rot2, neck0, neck1, neck2, jaw], dtype=torch.float)\n",
    "\n",
    "rgba, shaded_mesh = synthesize_novel_poses_and_expressions(expr=expr, pose=pose, image_size=(512, 512))\n",
    "\n",
    "# Show the first three channels of both results side by side, channels moved last for imshow.\n",
    "fig, axes = plt.subplots(ncols=2, figsize=(20, 10))\n",
    "panels = ((rgba, \"Rendered avatar\"), (shaded_mesh, \"Shaded mesh\"))\n",
    "for ax, (image, title) in zip(axes, panels):\n",
    "    ax.imshow(image[0, :3].cpu().permute(1, 2, 0))\n",
    "    ax.set_title(title)\n",
    "    ax.axis(\"off\")\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
